{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numbers\n",
    "import os\n",
    "import os.path\n",
    "import numpy as np\n",
    "import struct\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "import h5py\n",
    "import json\n",
    "\n",
    "from util import som\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "%matplotlib qt5 \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# for train and val\n",
    "def som_saver_shrec2016(root, rows, cols, gpu_ids, output_root, sample_num=4096):\n",
    "    \"\"\"Build SOM nodes for every .txt point cloud in the sub-folders of root.\n",
    "\n",
    "    Each text file is read with np.loadtxt; columns 0-2 are kept as 'pc' and\n",
    "    columns 3-5 as 'sn'. sample_num randomly chosen rows of 'pc' are handed to\n",
    "    som_builder.optimize and the resulting som_builder.node is kept as\n",
    "    'som_node'. The three arrays are written to\n",
    "    <output_root>/<file name without extension>.npz.\n",
    "\n",
    "    Args:\n",
    "        root: directory whose sub-directories contain the .txt files.\n",
    "        rows, cols: forwarded to som.SOM (presumably the SOM grid size -- confirm).\n",
    "        gpu_ids: forwarded unchanged to som.SOM.\n",
    "        output_root: directory receiving the .npz files; created if missing.\n",
    "        sample_num: number of points sampled from each cloud before fitting.\n",
    "    \"\"\"\n",
    "    som_builder = som.SOM(rows, cols, 3, gpu_ids)\n",
    "\n",
    "    # np.savez does not create missing parent directories.\n",
    "    os.makedirs(output_root, exist_ok=True)\n",
    "\n",
    "    for folder in os.listdir(root):\n",
    "        folder_path = os.path.join(root, folder)\n",
    "        if not os.path.isdir(folder_path):\n",
    "            # A stray file next to the folders would make os.listdir raise.\n",
    "            continue\n",
    "\n",
    "        for j, file in enumerate(os.listdir(folder_path)):\n",
    "            if not file.endswith('.txt'):\n",
    "                continue\n",
    "\n",
    "            # ndmin=2 keeps a single-row file two-dimensional so the column slices work.\n",
    "            data = np.loadtxt(os.path.join(folder_path, file), ndmin=2)\n",
    "            pc_np = data[:, 0:3]\n",
    "            sn_np = data[:, 3:6]\n",
    "\n",
    "            # Sampling without replacement needs at least sample_num points;\n",
    "            # smaller clouds are sampled with replacement instead of raising.\n",
    "            point_num = pc_np.shape[0]\n",
    "            choice = np.random.choice(point_num, sample_num, replace=point_num < sample_num)\n",
    "            pc_np_sampled = pc_np[choice, :]\n",
    "            pc = torch.from_numpy(pc_np_sampled.transpose().astype(np.float32)).cuda()  # 3xN tensor\n",
    "            som_builder.optimize(pc)\n",
    "            som_node_np = som_builder.node.cpu().numpy().transpose().astype(np.float32)  # node_numx3\n",
    "\n",
    "            npz_file = os.path.join(output_root, os.path.splitext(file)[0] + '.npz')\n",
    "            np.savez(npz_file, pc=pc_np, sn=sn_np, som_node=som_node_np)\n",
    "\n",
    "            # Lightweight progress report.\n",
    "            if j % 100 == 0:\n",
    "                print('%s, %s' % (folder, file))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the train/val saver (the definition that walks root/<folder>/*.txt) on both splits.\n",
    "# This cell must run before the test-set cell re-binds som_saver_shrec2016.\n",
    "rows, cols = 8, 8\n",
    "\n",
    "som_saver_shrec2016(\n",
    "    root='/ssd/dataset/SHREC2016/obj_txt/train',\n",
    "    rows=rows,\n",
    "    cols=cols,\n",
    "    gpu_ids=True,\n",
    "    output_root='/ssd/dataset/SHREC2016/%dx%d/train'%(rows, cols),\n",
    ")\n",
    "\n",
    "som_saver_shrec2016(\n",
    "    root='/ssd/dataset/SHREC2016/obj_txt/val',\n",
    "    rows=rows,\n",
    "    cols=cols,\n",
    "    gpu_ids=True,\n",
    "    output_root='/ssd/dataset/SHREC2016/%dx%d/val'%(rows, cols),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# for test set\n",
    "# NOTE: this re-uses the name of the train/val saver defined earlier in the\n",
    "# notebook; once this cell has run, that earlier definition is shadowed.\n",
    "def som_saver_shrec2016(root, rows, cols, gpu_ids, output_root, sample_num=4096):\n",
    "    \"\"\"Build SOM nodes for every .txt point cloud located directly in root.\n",
    "\n",
    "    Each text file is read with np.loadtxt; columns 0-2 are kept as 'pc' and\n",
    "    columns 3-5 as 'sn'. sample_num randomly chosen rows of 'pc' are handed to\n",
    "    som_builder.optimize and the resulting som_builder.node is kept as\n",
    "    'som_node'. The three arrays are written to\n",
    "    <output_root>/<file name without extension>.npz.\n",
    "\n",
    "    Args:\n",
    "        root: directory that directly contains the .txt files.\n",
    "        rows, cols: forwarded to som.SOM (presumably the SOM grid size -- confirm).\n",
    "        gpu_ids: forwarded unchanged to som.SOM.\n",
    "        output_root: directory receiving the .npz files; created if missing.\n",
    "        sample_num: number of points sampled from each cloud before fitting.\n",
    "    \"\"\"\n",
    "    som_builder = som.SOM(rows, cols, 3, gpu_ids)\n",
    "\n",
    "    # np.savez does not create missing parent directories.\n",
    "    os.makedirs(output_root, exist_ok=True)\n",
    "\n",
    "    for j, file in enumerate(os.listdir(root)):\n",
    "        if not file.endswith('.txt'):\n",
    "            continue\n",
    "\n",
    "        # ndmin=2 keeps a single-row file two-dimensional so the column slices work.\n",
    "        data = np.loadtxt(os.path.join(root, file), ndmin=2)\n",
    "        pc_np = data[:, 0:3]\n",
    "        sn_np = data[:, 3:6]\n",
    "\n",
    "        # Sampling without replacement needs at least sample_num points;\n",
    "        # smaller clouds are sampled with replacement instead of raising.\n",
    "        point_num = pc_np.shape[0]\n",
    "        choice = np.random.choice(point_num, sample_num, replace=point_num < sample_num)\n",
    "        pc_np_sampled = pc_np[choice, :]\n",
    "        pc = torch.from_numpy(pc_np_sampled.transpose().astype(np.float32)).cuda()  # 3xN tensor\n",
    "        som_builder.optimize(pc)\n",
    "        som_node_np = som_builder.node.cpu().numpy().transpose().astype(np.float32)  # node_numx3\n",
    "\n",
    "        npz_file = os.path.join(output_root, os.path.splitext(file)[0] + '.npz')\n",
    "        np.savez(npz_file, pc=pc_np, sn=sn_np, som_node=som_node_np)\n",
    "\n",
    "        # Lightweight progress report.\n",
    "        if j % 100 == 0:\n",
    "            print('%s' % (file))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Run the test-set saver (the definition that reads .txt files directly from root).\n",
    "rows, cols = 8, 8\n",
    "\n",
    "som_saver_shrec2016(\n",
    "    root='/ssd/dataset/SHREC2016/obj_txt/test_allinone',\n",
    "    rows=rows,\n",
    "    cols=cols,\n",
    "    gpu_ids=True,\n",
    "    output_root='/ssd/dataset/SHREC2016/%dx%d/test'%(rows, cols),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load one saved sample and plot its points together with its SOM nodes.\n",
    "file = '/ssd/dataset/SHREC2016/8x8/train/model_013435.npz'\n",
    "\n",
    "# np.load on an .npz returns an archive object that keeps the file open;\n",
    "# the context manager closes it once the arrays have been read.\n",
    "with np.load(file) as data:\n",
    "    pc_np = data['pc']\n",
    "    sn_np = data['sn']\n",
    "    som_node_np = data['som_node']\n",
    "\n",
    "print(pc_np)\n",
    "print(sn_np)\n",
    "print(som_node_np)\n",
    "\n",
    "fig = plt.figure()\n",
    "# Axes3D(fig) no longer attaches itself to the figure on recent matplotlib,\n",
    "# which leaves the window empty; add_subplot works on old and new versions.\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "ax.scatter(pc_np[:, 0], pc_np[:, 1], pc_np[:, 2], s=1)\n",
    "ax.scatter(som_node_np[:, 0], som_node_np[:, 1], som_node_np[:, 2], s=6, c='r')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Write the sorted ids of the generated test samples to test.txt, one per line.\n",
    "file_list = sorted(os.listdir('/ssd/dataset/SHREC2016/8x8/test'))\n",
    "\n",
    "# `with` closes (and flushes) the file even if a write raises.\n",
    "with open('test.txt', 'w') as f:\n",
    "    for file in file_list:\n",
    "        # Anything that is not an .npz output would yield a meaningless id.\n",
    "        if not file.endswith('.npz'):\n",
    "            continue\n",
    "        # [6:-4] drops the first six characters and the '.npz' suffix\n",
    "        # (presumably a 'model_' prefix, e.g. 'model_013435.npz' -- confirm).\n",
    "        f.write('%s\\n' % file[6:-4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
